import { RefObject } from '@mui/x-internals/types';
import {
  gridColumnLookupSelector,
  GridRowId,
  GridGroupNode,
  GridLeafNode,
  gridRowTreeSelector,
  GRID_ROOT_GROUP_ID,
  gridRowsLookupSelector,
  GridColumnLookup,
} from '@mui/x-data-grid-pro';
import { getVisibleRows } from '@mui/x-data-grid/internals';
import { GridApiPremium, GridPrivateApiPremium } from '../../../models/gridApiPremium';
import { DataGridPremiumProcessedProps } from '../../../models/dataGridPremiumProps';
import {
  GridAggregationFunction,
  GridAggregationLookup,
  GridAggregationPosition,
  GridAggregationRules,
} from './gridAggregationInterfaces';

/**
 * Raw (not yet aggregated) cell values collected for one group.
 * Holds one entry per aggregated field, in the same order as the `aggregatedFields` array it was built from.
 */
type AggregatedValues = Array<{ aggregatedField: string; values: any[] }>;

/**
 * Tells whether any of the aggregated fields uses an aggregation function that declares
 * `applySorting`. For example, functions whose result depends on the row order.
 *
 * @param aggregationRules Aggregation rules keyed by field.
 * @param aggregatedFields Fields currently aggregated; each is expected to have a rule.
 * @returns `true` as soon as one field's aggregation function has a truthy `applySorting`.
 */
export const shouldApplySorting = (
  aggregationRules: GridAggregationRules,
  aggregatedFields: string[],
) => {
  for (const field of aggregatedFields) {
    if (aggregationRules[field].aggregationFunction.applySorting) {
      return true;
    }
  }
  return false;
};

/**
 * Aggregates a single group from its direct children.
 *
 * Leaf children contribute their own cell values. Child groups contribute the raw values already
 * collected for them in `groupAggregatedValuesLookup`, so every child group must have been
 * processed before its parent.
 *
 * @param groupId Id of the group being aggregated.
 * @param apiRef Private grid API ref.
 * @param aggregationRowsScope When `'filtered'`, only children passing the filters are read.
 * @param aggregatedFields Fields to aggregate. The index of a field here is also its index in `AggregatedValues`.
 * @param aggregationRules Aggregation rules keyed by field.
 * @param position Where the aggregated value is displayed, or `null` when it is not displayed for this group.
 * @param applySorting Whether the children are read in sorted order.
 * @param valueGetters Per-field functions reading a cell value from a row.
 * @param publicApi Public API passed to `aggregationFunction.apply`.
 * @param groupAggregatedValuesLookup Raw values already collected for previously processed groups.
 * @param columnsLookup Column definitions, passed to `valueFormatter`.
 * @returns `groupAggregationLookup`, which is empty when `position` is `null`, and the raw
 * `aggregatedValues` collected for this group.
 */
const getGroupAggregatedValue = (
  groupId: GridRowId,
  apiRef: RefObject<GridPrivateApiPremium>,
  aggregationRowsScope: DataGridPremiumProcessedProps['aggregationRowsScope'],
  aggregatedFields: string[],
  aggregationRules: GridAggregationRules,
  position: GridAggregationPosition | null,
  applySorting: boolean,
  valueGetters: Record<string, (row: any) => any>,
  publicApi: GridApiPremium,
  groupAggregatedValuesLookup: Map<GridRowId, AggregatedValues>,
  columnsLookup: GridColumnLookup,
) => {
  const groupAggregationLookup: GridAggregationLookup[GridRowId] = {};
  const aggregatedValues: AggregatedValues = [];
  for (let i = 0; i < aggregatedFields.length; i += 1) {
    aggregatedValues[i] = {
      aggregatedField: aggregatedFields[i],
      values: [],
    };
  }

  // Resolve each field's aggregation function once instead of once per child row.
  const aggregationFunctions = aggregatedFields.map(
    (field) => aggregationRules[field].aggregationFunction as GridAggregationFunction,
  );

  const rowTree = gridRowTreeSelector(apiRef);
  const rowLookup = gridRowsLookupSelector(apiRef);
  const isPivotActive = apiRef.current.state.pivoting.active;

  const rowIds = apiRef.current.getRowGroupChildren({
    groupId,
    applySorting,
    directChildrenOnly: true,
    skipAutoGeneratedRows: false,
    applyFiltering: aggregationRowsScope === 'filtered',
  });

  for (let i = 0; i < rowIds.length; i += 1) {
    const rowId = rowIds[i];
    const rowNode = rowTree[rowId];

    if (rowNode.type === 'group') {
      // Merge the raw values already collected for this child group.
      const childGroupValues = groupAggregatedValuesLookup.get(rowId);
      if (childGroupValues) {
        for (let j = 0; j < aggregatedFields.length; j += 1) {
          const target = aggregatedValues[j].values;
          const source = childGroupValues[j].values;
          // Append in place. `target` is this group's own array, so the child's array is left untouched.
          // A `concat` per child group would re-copy the whole accumulated array every time.
          for (let k = 0; k < source.length; k += 1) {
            target.push(source[k]);
          }
        }
      }
      continue;
    }

    const row = rowLookup[rowId];
    if (!row) {
      continue;
    }

    for (let j = 0; j < aggregatedFields.length; j += 1) {
      const field = aggregatedFields[j];
      const aggregationFunction = aggregationFunctions[j];

      let value;
      if (typeof aggregationFunction.getCellValue === 'function') {
        value = aggregationFunction.getCellValue({ field, row });
      } else if (isPivotActive) {
        // Since we know that pivoted fields are flat, we can use the row directly, and save lots of processing time
        value = row[field];
      } else {
        const valueGetter = valueGetters[field]!;
        value = valueGetter(row);
      }

      if (value !== undefined) {
        aggregatedValues[j].values.push(value);
      }
    }
  }

  // Nothing is displayed for this group. The raw values are still returned so the caller can
  // pass them on to parent groups, but `apply`/`valueFormatter` results would be discarded.
  if (position === null) {
    return { groupAggregationLookup, aggregatedValues };
  }

  for (let i = 0; i < aggregatedValues.length; i += 1) {
    const { aggregatedField, values } = aggregatedValues[i];
    const aggregationFunction = aggregationFunctions[i];
    const value = aggregationFunction.apply(
      {
        values,
        groupId,
        field: aggregatedField, // Added per user request in https://github.com/mui/mui-x/issues/6995#issuecomment-1327423455
      },
      publicApi,
    );
    const formattedValue = aggregationFunction.valueFormatter
      ? aggregationFunction.valueFormatter(
          value as never,
          rowLookup[groupId],
          columnsLookup[aggregatedField],
          apiRef,
        )
      : undefined;

    groupAggregationLookup[aggregatedField] = {
      position,
      value,
      formattedValue,
    };
  }

  return { groupAggregationLookup, aggregatedValues };
};

/**
 * Builds a group's aggregation entries in data source mode. The values come from
 * `apiRef.current.resolveGroupAggregation` instead of being computed from the rows.
 *
 * @param groupId Id of the group being resolved.
 * @param apiRef Private grid API ref.
 * @param aggregatedFields Fields to resolve.
 * @param position Where the aggregated values are displayed.
 * @returns One `{ position, value }` entry per field. The value is `''` when the resolver is
 * missing or returns `null`/`undefined`.
 */
const getGroupAggregatedValueDataSource = (
  groupId: GridRowId,
  apiRef: RefObject<GridPrivateApiPremium>,
  aggregatedFields: string[],
  position: GridAggregationPosition,
) => {
  const lookup: GridAggregationLookup[GridRowId] = {};

  for (const field of aggregatedFields) {
    const resolved = apiRef.current.resolveGroupAggregation?.(groupId, field);
    lookup[field] = { position, value: resolved ?? '' };
  }

  return lookup;
};

/**
 * Builds the aggregation lookup for the whole row tree, keyed by group id.
 *
 * The tree is walked depth-first from the root. All children of a group are processed before the
 * group itself, so the raw values of child groups are available in `groupAggregatedValuesLookup`
 * when the parent is aggregated.
 *
 * @param params.apiRef Private grid API ref.
 * @param params.aggregationRules Aggregation rules keyed by field.
 * @param params.aggregatedFields Fields to aggregate; when empty, an empty lookup is returned.
 * @param params.aggregationRowsScope Which rows (all or only filtered) feed the aggregation.
 * @param params.getAggregationPosition Returns where a group's aggregation is displayed, or `null`.
 * @param params.isDataSource When `true`, values are resolved through `resolveGroupAggregation`.
 * @param params.applySorting Whether children are read in sorted order.
 * @returns An entry for every group with a non-null aggregation position. Outside data source
 * mode, the group must also have at least one child.
 */
export const createAggregationLookup = ({
  apiRef,
  aggregationRules,
  aggregatedFields,
  aggregationRowsScope,
  getAggregationPosition,
  isDataSource,
  applySorting = false,
}: {
  apiRef: RefObject<GridPrivateApiPremium>;
  aggregationRules: GridAggregationRules;
  aggregatedFields: string[];
  aggregationRowsScope: DataGridPremiumProcessedProps['aggregationRowsScope'];
  getAggregationPosition: DataGridPremiumProcessedProps['getAggregationPosition'];
  isDataSource: boolean;
  applySorting: boolean;
}): GridAggregationLookup => {
  if (aggregatedFields.length === 0) {
    return {};
  }

  const columnsLookup = gridColumnLookupSelector(apiRef);
  const valueGetters: Record<string, (row: any) => any> = {};
  for (let i = 0; i < aggregatedFields.length; i += 1) {
    const field = aggregatedFields[i];
    const column = columnsLookup[field];
    valueGetters[field] = (row: any) => apiRef.current.getRowValue(row, column);
  }

  const aggregationLookup: GridAggregationLookup = {};
  const rowTree = gridRowTreeSelector(apiRef);

  // Raw (not yet aggregated) values of every processed group, reused by their parent groups.
  const groupAggregatedValuesLookup = new Map<GridRowId, AggregatedValues>();

  // The visible rows are only used to order the traversal when sorting is applied.
  const rowIdToIndexMap = applySorting ? getVisibleRows(apiRef).rowIdToIndexMap : undefined;
  // Ids missing from the visible rows have no index. Sorting them last keeps the comparator
  // consistent; subtracting `undefined` indices would return `NaN`.
  const getSortIndex = (id: GridRowId) => rowIdToIndexMap?.get(id) ?? Number.MAX_SAFE_INTEGER;

  const createGroupAggregationLookup = (groupNode: GridGroupNode) => {
    let children = groupNode.children;
    if (applySorting) {
      // NOTE(review): results are keyed by id and every child is processed before its parent,
      // so this traversal order does not appear to affect the computed values. Confirm whether it is still needed.
      children = children.toSorted((a, b) => getSortIndex(a) - getSortIndex(b));
    }
    for (let i = 0; i < children.length; i += 1) {
      const childId = children[i];
      const childNode = rowTree[childId] as GridGroupNode | GridLeafNode;

      if (childNode.type === 'group') {
        createGroupAggregationLookup(childNode);
      }
    }

    const position = getAggregationPosition(groupNode);

    if (isDataSource) {
      if (position !== null) {
        aggregationLookup[groupNode.id] = getGroupAggregatedValueDataSource(
          groupNode.id,
          apiRef,
          aggregatedFields,
          position,
        );
      }
    } else if (groupNode.children.length) {
      const result = getGroupAggregatedValue(
        groupNode.id,
        apiRef,
        aggregationRowsScope,
        aggregatedFields,
        aggregationRules,
        position,
        applySorting,
        valueGetters,
        apiRef.current,
        groupAggregatedValuesLookup,
        columnsLookup,
      );
      // Always populate groupAggregatedValuesLookup for groups with children
      // This ensures parent groups can access child aggregated values even when position is null
      groupAggregatedValuesLookup.set(groupNode.id, result.aggregatedValues);
      // Only set aggregationLookup if position is not null (meaning aggregation should be displayed)
      if (position !== null) {
        aggregationLookup[groupNode.id] = result.groupAggregationLookup;
      }
    }
  };

  createGroupAggregationLookup(rowTree[GRID_ROOT_GROUP_ID] as GridGroupNode);

  return aggregationLookup;
};
